#----------------------------------------------------------------------
#  GFDM method test - 3d elastic plate with a hole
#  Author: Andrea Pavan
#  Date: 14/02/2023
#  License: GPLv3-or-later
#----------------------------------------------------------------------
using ElasticArrays;
using LinearAlgebra;
using LinearSolve;
using SparseArrays;
using PyPlot;
include("utils.jl");


#problem definition
l1 = 10.0;      #domain x size
l2 = 10.0;      #domain y size
l3 = 1.0;       #domain z size
a = 1.0;        #hole radius
σ0 = 1.0;       #traction load (applied along x on the face x=l1)
E = 1000;       #Young modulus
ν = 0.3;        #Poisson ratio

meshSize = 0.5;                 #characteristic node spacing, passed to the neighbor search
surfaceMeshSize = meshSize;     #not referenced elsewhere in this file
minNeighbors = 60;              #minimum stencil size requested from the neighbor search
minSearchRadius = meshSize/5;   #not referenced elsewhere in this file

#Lamé coefficients of an isotropic linear elastic material:
#  μ = E/(2(1+ν))          (shear modulus)
#  λ = Eν/((1+ν)(1-2ν))
#the solver only depends on the ratio λ/μ = 2ν/(1-2ν), so μ must be consistent with λ
μ = 0.5*E/(1+ν);
λ = E*ν/((1+ν)*(1-2*ν));


#read pointcloud from a SU2 file
time1 = time();
boundaryNodes = Vector{Int}(undef,0);       #indices of the boundary nodes
internalNodes = Vector{Int}(undef,0);       #indices of the internal nodes
normals = ElasticArray{Float64}(undef,3,0);     #3xN matrix containing the components [nx;ny;nz] of the outward normal of each node (zero for internal nodes)

#3xN matrix containing the coordinates [X;Y;Z] of each node
pointcloud = parseSU2mesh("18d_3d_hole_plate_mesh_5082.su2");

#drop the nodes lying on the z-direction edges through the in-plane corners of the
#quarter plate (presumably because their normal is ambiguous).
#NOTE: exact float comparison, relying on the coordinates stored in the mesh file
cornersXY = [(a,0.0), (l1,0.0), (l1,l2), (0.0,l2), (0.0,a), (0.0,0.0)];
isCornerNode = falses(size(pointcloud,2));
for (xc,yc) in cornersXY
    #in-place update: no rebinding of a global inside the loop (soft scope)
    isCornerNode .|= (pointcloud[1,:].==xc) .& (pointcloud[2,:].==yc);
end
pointcloud = pointcloud[:, .!isCornerNode];
N = size(pointcloud,2);

tol = 1e-5;     #geometric tolerance used to detect the boundary nodes
for i=1:N
    (xi, yi, zi) = (pointcloud[1,i], pointcloud[2,i], pointcloud[3,i]);
    if xi^2+yi^2 <= (a+tol)^2
        #hole - the outward normal of the solid points towards the hole axis (z).
        #It lies in the xy-plane: the z coordinate must not enter the normal.
        #(the node at the origin, if any, was removed above, so ri>0)
        ri = hypot(xi,yi);
        push!(boundaryNodes, i);
        append!(normals, [-xi/ri, -yi/ri, 0.0]);
    elseif xi>=l1-tol
        #right
        push!(boundaryNodes, i);
        append!(normals, [1,0,0]);
    elseif xi<=0+tol
        #left
        push!(boundaryNodes, i);
        append!(normals, [-1,0,0]);
    elseif yi>=l2-tol
        #top
        push!(boundaryNodes, i);
        append!(normals, [0,1,0]);
    elseif yi<=0+tol
        #bottom
        push!(boundaryNodes, i);
        append!(normals, [0,-1,0]);
    elseif zi>=l3-tol
        #front
        push!(boundaryNodes, i);
        append!(normals, [0,0,1]);
    elseif zi<=0+tol
        #rear
        push!(boundaryNodes, i);
        append!(normals, [0,0,-1]);
    else
        #internal
        push!(internalNodes, i);
        append!(normals, [0,0,0]);
    end
end
println("Generated pointcloud in ", round(time()-time1,digits=2), " s");
println("Pointcloud properties:");
println("  Boundary nodes: ",length(boundaryNodes));
println("  Internal nodes: ",length(internalNodes));
println("  Memory: ",memoryUsage(pointcloud,boundaryNodes));


#boundary conditions
#for each displacement component c ∈ {u,v,w} a boundary node carries a triplet
#(g1c, g2c, g3c): g1c weighs the displacement, g2c weighs the traction
#component divided by μ, and g3c is the prescribed value (also scaled by 1/μ
#when it enters the right-hand side). (1,0,0) means c=0, (0,1,t) means traction t.
N = size(pointcloud,2);     #number of nodes
g1u, g2u, g3u = zeros(Float64,N), zeros(Float64,N), zeros(Float64,N);
g1v, g2v, g3v = zeros(Float64,N), zeros(Float64,N), zeros(Float64,N);
g1w, g2w, g3w = zeros(Float64,N), zeros(Float64,N), zeros(Float64,N);
for node in boundaryNodes
    px = pointcloud[1,node];
    py = pointcloud[2,node];
    if px >= l1-1e-5
        #right - tx=σ0
        bcU, bcV, bcW = (0.0,1.0,σ0), (0.0,1.0,0.0), (0.0,1.0,0.0);
    elseif px <= 0+1e-5
        #left - u=0, w=0
        bcU, bcV, bcW = (1.0,0.0,0.0), (0.0,1.0,0.0), (1.0,0.0,0.0);
    elseif py <= 0+1e-5
        #bottom - v=0, w=0
        bcU, bcV, bcW = (0.0,1.0,0.0), (1.0,0.0,0.0), (1.0,0.0,0.0);
    else
        #everywhere else - traction free
        bcU, bcV, bcW = (0.0,1.0,0.0), (0.0,1.0,0.0), (0.0,1.0,0.0);
    end
    g1u[node], g2u[node], g3u[node] = bcU;
    g1v[node], g2v[node], g3v[node] = bcV;
    g1w[node], g2w[node], g3w[node] = bcW;
end
Fx = zeros(Float64,N);      #volumetric loads
Fy = zeros(Float64,N);
Fz = zeros(Float64,N);


#neighbor search (cartesian cells)
time2 = time();
N = size(pointcloud,2);     #number of nodes
neighbors = Vector{Vector{Int}}(undef,N);       #neighbors[i]: indices of the neighbors of node i
Nneighbors = zeros(Int,N);      #Nneighbors[i]: number of neighbors of node i
(neighbors,Nneighbors,cell) = cartesianNeighborSearch(pointcloud,meshSize,minNeighbors);
println("Found neighbors in ", round(time()-time2,digits=2), " s");

#stencil size statistics (findmax/findmin report the first extremal index)
(maxNb, maxNbIdx) = findmax(Nneighbors);
(minNb, minNbIdx) = findmin(Nneighbors);
avgNb = round(sum(Nneighbors)/length(Nneighbors),digits=2);
println("Connectivity properties:");
println("  Max neighbors: ",maxNb," (at index ",maxNbIdx,")");
println("  Avg neighbors: ",avgNb);
println("  Min neighbors: ",minNb," (at index ",minNbIdx,")");


#neighbors distances and weights
time3 = time();
P = Vector{Array{Float64}}(undef,N);        #P[i][:,j]: position of the j-th neighbor relative to node i
r2 = Vector{Vector{Float64}}(undef,N);      #squared distances of the neighbors
w = Vector{Vector{Float64}}(undef,N);       #least-squares weights of the neighbors
for i=1:N
    nbIdx = neighbors[i];
    relPos = Array{Float64}(undef,3,Nneighbors[i]);
    dist2 = Vector{Float64}(undef,Nneighbors[i]);
    for j in eachindex(dist2)
        d = pointcloud[:,nbIdx[j]]-pointcloud[:,i];
        relPos[:,j] = d;
        dist2[j] = d'd;
    end
    P[i] = relPos;
    r2[i] = dist2;
    #gaussian decay, normalized by the farthest neighbor of the stencil
    w[i] = exp.(-6 .* dist2 ./ maximum(dist2));
end
wpde = 2.0;       #least squares weight for the pde
wbc = 2.0;        #least squares weight for the boundary condition


#least square matrix inversion
#At each node the three displacement components are approximated by quadratic
#polynomials in the relative coordinates:
#   u ≈ c[1:10]·φ,  v ≈ c[11:20]·φ,  w ≈ c[21:30]·φ,
#   φ = [1, x, y, z, x², y², z², yz, xz, xy]
#The 30 coefficients are fitted (weighted least squares) to the neighbor values,
#to the Navier equations divided by μ and, on boundary nodes, to the boundary
#conditions. C[i] maps the local right-hand side
#   [u_j; v_j; w_j; pde rhs (3 values); bc rhs (3 values, boundary nodes only)]
#to the coefficient vector c.

#quadratic monomials evaluated at the relative position (x,y,z)
quadraticBasis(x, y, z) = [1.0, x, y, z, x^2, y^2, z^2, y*z, x*z, x*y];

#coefficient rows of ∇²u + (1+λ/μ)∇(∇·u), one row per Cartesian equation
function navierRows(λ, μ)
    R = zeros(Float64,3,30);
    #x equation: uxx+uyy+uzz + (1+λ/μ)(uxx+vxy+wxz)
    R[1,5] = 2*(2+λ/μ);  R[1,6] = 2;  R[1,7] = 2;  R[1,20] = 1+λ/μ;  R[1,29] = 1+λ/μ;
    #y equation: vxx+vyy+vzz + (1+λ/μ)(uxy+vyy+wyz)
    R[2,10] = 1+λ/μ;  R[2,15] = 2;  R[2,16] = 2*(2+λ/μ);  R[2,17] = 2;  R[2,28] = 1+λ/μ;
    #z equation: wxx+wyy+wzz + (1+λ/μ)(uxz+vyz+wzz)
    R[3,9] = 1+λ/μ;  R[3,18] = 1+λ/μ;  R[3,25] = 2;  R[3,26] = 2;  R[3,27] = 2*(2+λ/μ);
    return R;
end

#coefficient rows of g1*c + g2*(t_c/μ) for c ∈ {u,v,w}, where t = σ·n is the
#traction on a surface with outward normal n = (nx,ny,nz)
function boundaryRows(nx, ny, nz, g1u, g2u, g1v, g2v, g1w, g2w, λ, μ)
    R = zeros(Float64,3,30);
    #u row: g1u*u + g2u*(nx*σxx + ny*σxy + nz*σxz)/μ
    R[1,1] = g1u;
    R[1,2] = g2u*nx*(2+λ/μ);  R[1,3] = g2u*ny;  R[1,4] = g2u*nz;
    R[1,12] = g2u*ny;  R[1,13] = g2u*nx*λ/μ;
    R[1,22] = g2u*nz;  R[1,24] = g2u*nx*λ/μ;
    #v row: g1v*v + g2v*(nx*σxy + ny*σyy + nz*σyz)/μ
    R[2,2] = g2v*ny*λ/μ;  R[2,3] = g2v*nx;
    R[2,11] = g1v;  R[2,12] = g2v*nx;  R[2,13] = g2v*ny*(2+λ/μ);  R[2,14] = g2v*nz;
    R[2,23] = g2v*nz;  R[2,24] = g2v*ny*λ/μ;
    #w row: g1w*w + g2w*(nx*σxz + ny*σyz + nz*σzz)/μ
    R[3,2] = g2w*nz*λ/μ;  R[3,4] = g2w*nx;
    R[3,13] = g2w*nz*λ/μ;  R[3,14] = g2w*ny;
    R[3,21] = g1w;  R[3,22] = g2w*nx;  R[3,23] = g2w*ny;  R[3,24] = g2w*nz*(2+λ/μ);
    return R;
end

#weighted least-squares pseudo-inverse (via SVD) of the local collocation matrix
#built from the nb neighbor offsets Pi (3×nb) plus the extra equation rows
function stencilCoefficients(Pi, nb, extraRows, weights)
    V = zeros(Float64,3*nb+size(extraRows,1),30);
    for j=1:nb
        φ = quadraticBasis(Pi[1,j], Pi[2,j], Pi[3,j]);     #x, y AND z offsets
        V[j,1:10] = φ;            #u at neighbor j
        V[j+nb,11:20] = φ;        #v at neighbor j
        V[j+2*nb,21:30] = φ;      #w at neighbor j
    end
    V[3*nb+1:end,:] = extraRows;
    W = Diagonal(weights);
    VF = svd(W*V);
    return transpose(VF.Vt)*inv(Diagonal(VF.S))*transpose(VF.U)*W;
end

C = Vector{Matrix{Float64}}(undef,N);       #stencil coefficients matrices
pdeRows = navierRows(λ, μ);
for i in internalNodes
    C[i] = stencilCoefficients(P[i], Nneighbors[i], pdeRows, vcat(w[i],w[i],w[i],wpde,wpde,wpde));
end
for i in boundaryNodes
    #each normal component comes from its own row of normals
    nx = normals[1,i];
    ny = normals[2,i];
    nz = normals[3,i];
    bcRows = boundaryRows(nx, ny, nz, g1u[i], g2u[i], g1v[i], g2v[i], g1w[i], g2w[i], λ, μ);
    C[i] = stencilCoefficients(P[i], Nneighbors[i], vcat(pdeRows,bcRows), vcat(w[i],w[i],w[i],wpde,wpde,wpde,wbc,wbc,wbc));
end
println("Inverted least-squares matrices in ", round(time()-time3,digits=2), " s");

#optional diagnostics: condition number distribution of the stencils
#=condC = [cond(C[i]) for i=1:N];
println("Matrix properties:");
println("  Max cond(C): ",maximum(condC));
println("  Avg cond(C): ",sum(condC)/N);
println("  Min cond(C): ",minimum(condC));
figure();
plt.hist(condC,10);
title("Condition number distribution");
xlabel("cond(C)");
ylabel("Absolute frequency");
display(gcf());=#


#matrix assembly
#unknowns are ordered [u; v; w]. The equation of component k at node i is
#   (value at i) - Σ_j C[i][r_k, ·]·(u_j, v_j, w_j) = rhs
#where r_k ∈ (1, 11, 21) is the row of C[i] returning u, v or w at node i
time4 = time();
rows = Int[];
cols = Int[];
vals = Float64[];
for i=1:N
    nb = Nneighbors[i];
    for (k, r) in enumerate((1, 11, 21))
        eqRow = (k-1)*N + i;
        #diagonal term
        push!(rows, eqRow);
        push!(cols, eqRow);
        push!(vals, 1);
        #neighbor contributions: u (comp=0), v (comp=1), w (comp=2)
        for j=1:nb, comp=0:2
            push!(rows, eqRow);
            push!(cols, comp*N+neighbors[i][j]);
            push!(vals, -C[i][r,j+comp*nb]);
        end
    end
end
M = sparse(rows,cols,vals,3*N,3*N);
println("Completed matrix assembly in ", round(time()-time4,digits=2), " s");


#linear system solution
time5 = time();
b = zeros(3*N);       #rhs vector, ordered like the unknowns [u; v; w]
stateRows = (1, 11, 21);     #rows of C[i] returning u, v and w at node i
for i=1:N
    nb3 = 3*Nneighbors[i];
    for (k, r) in enumerate(stateRows)
        #volumetric loads (pde right-hand side -F/μ)
        b[(k-1)*N+i] += -C[i][r,1+nb3]*Fx[i]/μ -C[i][r,2+nb3]*Fy[i]/μ -C[i][r,3+nb3]*Fz[i]/μ;
    end
end
for i in boundaryNodes
    nb3 = 3*Nneighbors[i];
    for (k, r) in enumerate(stateRows)
        #boundary data g3/μ
        b[(k-1)*N+i] += C[i][r,4+nb3]*g3u[i]/μ + C[i][r,5+nb3]*g3v[i]/μ + C[i][r,6+nb3]*g3w[i]/μ;
    end
end
sol = M\b;
println("Linear system solved in ", round(time()-time5,digits=2), " s");


#displacement fields (note: w is rebound here, it no longer holds the neighbor weights)
(u, v, w) = (sol[(k-1)*N .+ (1:N)] for k in 1:3);
#optional in-plane displacement plots
#=for (field, name) in ((u, "x"), (v, "y"))
    figure();
    scatter(pointcloud[1,:],pointcloud[2,:],c=field,cmap="jet");
    colorbar();
    title("Numerical solution - "*name*" displacement");
    axis("equal");
    display(gcf());
end=#


#post-processing - σxx stress
#derivatives are read from the fitted coefficients: row 2 of C[i] gives ∂u/∂x,
#row 13 gives ∂v/∂y, row 24 gives ∂w/∂z. Like the solution itself, they depend
#on the whole local right-hand side: neighbor displacements, pde rhs (-F/μ, same
#convention as the rhs assembly) and, on boundary nodes, the boundary data g3/μ.
dudx = zeros(N);
dvdy = zeros(N);
dwdz = zeros(N);
derivativeRows = ((dudx,2), (dvdy,13), (dwdz,24));
for i=1:N
    nb = Nneighbors[i];
    for (d, r) in derivativeRows
        for j=1:nb
            k = neighbors[i][j];
            d[i] += C[i][r,j]*u[k] + C[i][r,j+nb]*v[k] + C[i][r,j+2*nb]*w[k];
        end
        #volumetric load contribution (zero in this test, where Fx=Fy=Fz=0)
        d[i] += -C[i][r,1+3*nb]*Fx[i]/μ - C[i][r,2+3*nb]*Fy[i]/μ - C[i][r,3+3*nb]*Fz[i]/μ;
    end
end
for i in boundaryNodes
    nb = Nneighbors[i];
    for (d, r) in derivativeRows
        d[i] += C[i][r,4+3*nb]*g3u[i]/μ + C[i][r,5+3*nb]*g3v[i]/μ + C[i][r,6+3*nb]*g3w[i]/μ;
    end
end
#Hooke's law: σxx = (2μ+λ)εxx + λ(εyy+εzz)
σxx = dudx*(2μ+λ) + dvdy*λ + dwdz*λ;

#stress plot (projection of all z levels on the xy-plane)
figure();
scatter(pointcloud[1,:],pointcloud[2,:],c=σxx,cmap="jet");
colorbar();
title("Hole plate - σxx stress");
axis("equal");
display(gcf());


#validation plot
#σxx on the symmetry plane x=0 (nodes of all z levels) against the analytical
#(Kirsch) stress along the line θ=90° for uniaxial tension σ0 along x:
#   σxx(0,y) = σ0*(1 + a²/(2y²) + 3a⁴/(2y⁴))
idxPlot = findall(pointcloud[1,:].<=0+1e-5);
y_exact = collect(a:0.01:l2);       #the plane x=0 spans y ∈ [a, l2]
σxx_exact = @. σ0*(1 + 0.5*(a/y_exact)^2 + 1.5*(a/y_exact)^4);
figure();
plot(pointcloud[2,idxPlot],σxx[idxPlot],"r.",label="GFDM");
plot(y_exact,σxx_exact,"k-",linewidth=1.0,label="Analytical");
title("σxx stress @x=0");
legend(loc="upper right");
xlabel("y coordinate");
ylabel("σxx stress");
display(gcf());
